#!/usr/bin/env python3
from __future__ import annotations
import argparse, pathlib, sys, os, re, importlib.util, inspect, csv
from typing import Any, Dict, List
import numpy as np

# Directory containing this script, and the project root directly above it.
MODULE_DIR = pathlib.Path(__file__).resolve().parent
ROOT = MODULE_DIR.parent

# Prepend the project root to the import path (once) so modules under it are
# importable regardless of the current working directory.
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path[:0] = [str(ROOT)]

def load_yaml(path: str) -> Dict[str, Any]:
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result.

    ``safe_load`` yields ``None`` for an empty document; that case is mapped to
    ``{}`` so callers can always call ``.get`` on the returned config.
    """
    import yaml  # imported here so this module can be imported without PyYAML installed
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return {} if data is None else data

def ensure_dir(p: pathlib.Path) -> None:
    """Make sure directory ``p`` exists, creating missing parents as needed.

    Calling it on an existing directory is a no-op.
    """
    if not p.is_dir():
        p.mkdir(exist_ok=True, parents=True)

def save_csv_row(csv_path: pathlib.Path, row: Dict[str, Any], header: List[str]) -> None:
    """Append one row to a CSV file, writing the header first if the file is new or empty.

    Args:
        csv_path: Destination file (``pathlib.Path`` or ``str``); opened in append mode.
        row: Values keyed by column name. Columns in ``header`` that are absent
            from ``row`` are written as empty cells. The caller's dict is not modified.
        header: Column names, in output order. A key in ``row`` that is not in
            ``header`` makes ``csv.DictWriter`` raise ``ValueError``.
    """
    csv_path = pathlib.Path(csv_path)
    # A zero-byte file (pre-created or left by an interrupted run) still needs a header.
    need_hdr = not csv_path.exists() or csv_path.stat().st_size == 0
    # Fill missing columns on a copy rather than mutating the caller's dict.
    full_row: Dict[str, Any] = dict.fromkeys(header)
    full_row.update(row)
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=header)
        if need_hdr:
            w.writeheader()
        w.writerow(full_row)

def to_abs(rel_or_abs: str) -> str:
    """Return an absolute path string for ``rel_or_abs``.

    Absolute inputs are returned unchanged (not normalised or resolved); relative
    inputs are interpreted against the project ``ROOT`` and fully resolved.
    """
    if os.path.isabs(rel_or_abs):
        return rel_or_abs
    anchored = ROOT / rel_or_abs
    return str(anchored.resolve())

def kernel_path_from_tpl(tpl: str, gauge: str, L: int) -> str:
    """Fill ``{gauge}`` and ``{L}`` in a kernel-path template and return it via ``to_abs``."""
    filled = tpl.format(L=L, gauge=gauge)
    return to_abs(filled)

def flip_counts_path_from_tpl(tpl: str, L: int) -> str:
    """Fill ``{L}`` in a flip-counts path template and return it via ``to_abs``."""
    filled = tpl.format(L=L)
    return to_abs(filled)

def import_module_from_path(p: pathlib.Path, modname: str):
    """Import the Python source file ``p`` as a module registered under ``modname``.

    Returns:
        The executed module, or ``None`` when no import spec/loader can be built
        for ``p`` (e.g. the file suffix is not a recognised source suffix).

    Raises:
        Whatever executing the module raises. In that case the half-initialised
        entry is removed from ``sys.modules`` before the exception propagates,
        so no broken module is left behind.
    """
    spec = importlib.util.spec_from_file_location(modname, str(p))
    if spec is None or spec.loader is None:
        return None
    mod = importlib.util.module_from_spec(spec)
    # Registered before execution, following the importlib "import a source file" recipe.
    sys.modules[spec.name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        sys.modules.pop(spec.name, None)
        raise
    return mod

def _iter_compute_candidates(explicit: pathlib.Path | None, func_name: str):
    """Lazily yield candidate module paths: the explicit module first, then scanned files.

    Scanned files are ``*.py`` under ``MODULE_DIR/orig``, ``ROOT/orig`` and
    ``MODULE_DIR`` (sorted per root for a deterministic order) whose text
    contains a ``def <func_name>(`` definition.
    """
    if explicit is not None:
        yield explicit
    # Compiled once instead of once per scanned file.
    def_pattern = re.compile(r"\bdef\s+%s\s*\(" % re.escape(func_name))
    for root in (MODULE_DIR / "orig", ROOT / "orig", MODULE_DIR):
        if not root.exists():
            continue
        for p in sorted(root.rglob("*.py")):
            try:
                txt = p.read_text(encoding="utf-8", errors="ignore")
            except OSError:
                continue
            if def_pattern.search(txt):
                yield p


def find_compute_callable(cfg: Dict[str, Any]):
    """Locate the crossover compute function and return ``(fn, inspect.Signature)``.

    The function name comes from ``crossover_analysis.compute_function`` (default
    ``"compute_sigma_c"``). Candidates are tried in order: the module named by
    ``crossover_analysis.compute_module`` (relative paths resolved via ``to_abs``),
    then auto-discovered source files (see ``_iter_compute_candidates``).
    Candidate files are imported, i.e. executed, one at a time until a callable
    with that name is found; later candidates are neither read nor imported.

    An exception while importing the explicitly configured module propagates.
    Import errors in auto-discovered files are recorded and reported only if no
    candidate succeeds.

    Raises:
        RuntimeError: if no candidate module exposes a callable with that name.
    """
    # A YAML key with no value parses as None, so fall back with `or`.
    cxa = dict(cfg.get("crossover_analysis") or {})
    explicit_path = cxa.get("compute_module")
    func_name = cxa.get("compute_function", "compute_sigma_c")
    explicit = pathlib.Path(to_abs(explicit_path)) if explicit_path else None

    seen: set = set()
    import_errors: List[str] = []
    for p in _iter_compute_candidates(explicit, func_name):
        if not p.exists():
            continue
        key = p.resolve()  # dedupe files reachable from several roots / via symlinks
        if key in seen:
            continue
        seen.add(key)
        try:
            mod = import_module_from_path(p, modname=f"co_mod_{abs(hash(p))}")
        except Exception as exc:
            if explicit is not None and p == explicit:
                raise
            import_errors.append(f"{p}: {type(exc).__name__}: {exc}")
            continue
        fn = getattr(mod, func_name, None) if mod is not None else None
        if callable(fn):
            return fn, inspect.signature(fn)

    msg = "Could not locate crossover compute function; set crossover_analysis.compute_module/function in YAML."
    if import_errors:
        msg += " Candidates that failed to import: " + "; ".join(import_errors)
    raise RuntimeError(msg)

def _as_list(value: Any) -> List[Any]:
    """Normalise a config value to a list.

    ``None`` (a YAML key with no value) becomes ``[]``; lists/tuples are copied;
    any other single value (a lone number or string) becomes a one-item list
    instead of being iterated, e.g. character by character.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _kernel_templates(cfg: Dict[str, Any], cxa: Dict[str, Any], gauges: List[Any]) -> Dict[Any, str]:
    """Map every gauge to its kernel-path template.

    A mapping at top-level ``kernel_path_template`` takes precedence and must
    have an entry for each gauge. Otherwise the single
    ``crossover_analysis.kernel_path_template`` string is used for all gauges.
    """
    top = cfg.get("kernel_path_template")
    if isinstance(top, dict):
        missing = [g for g in gauges if g not in top]
        if missing:
            raise ValueError(f"kernel_path_template has no entry for gauge(s): {missing}")
        return dict(top)
    ktpl = cxa.get("kernel_path_template")
    if not ktpl:
        raise ValueError("crossover_analysis.kernel_path_template missing.")
    return {g: ktpl for g in gauges}


def _check_kernel_shape(kernel_path: str, L: int, gauge: Any) -> None:
    """Check that a kernel array's leading dimension equals the link count 2*L^2.

    Only 1-D (length 2*L^2) and 3-D (first dim 2*L^2) arrays are accepted. The
    file is memory-mapped read-only, so the data is not loaded into RAM just to
    inspect its shape.
    """
    exp_links = 2 * L * L
    K = np.load(kernel_path, mmap_mode="r")
    if K.ndim == 1:
        if K.size != exp_links:
            raise ValueError(f"Kernel length {K.size} != 2*L^2 ({exp_links}) for L={L}, gauge={gauge}")
    elif K.ndim == 3:
        if K.shape[0] != exp_links:
            raise ValueError(f"Kernel shape {K.shape} first dim != 2*L^2 ({exp_links}) for L={L}, gauge={gauge}")
    else:
        raise ValueError(f"Unsupported kernel shape {K.shape} for {kernel_path}")


def _unpack_result(result: Any) -> tuple[float, float | None]:
    """Split a compute result into ``(sigma_c, sigma_c_err)``.

    Accepts a bare number, or a tuple/list whose first item is the value and
    whose optional second item is its error. A missing or ``None`` error is
    returned as ``None``.
    """
    if isinstance(result, (tuple, list)):
        if not result:
            raise ValueError("Crossover compute function returned an empty sequence.")
        err = result[1] if len(result) >= 2 else None
        return float(result[0]), (None if err is None else float(err))
    return float(result), None


def _rel_to_root(path: str) -> str:
    """Express ``path`` relative to ``ROOT`` for the CSV.

    Falls back to ``path`` itself when no relative form exists (``os.path.relpath``
    raises ``ValueError`` across Windows drives).
    """
    try:
        return os.path.relpath(path, ROOT)
    except ValueError:
        return path


def main() -> None:
    """Run the crossover (sigma_c) sweep over L x gauge x b x k x n0 and append results to CSV.

    Reads the YAML config given by ``--config``. For every combination it calls
    the function found by ``find_compute_callable`` with the kernel and
    flip-count paths, the pivot dict and the (b, k, n0) point. Each result is
    appended as one row of ``<output-dir>/crossover_summary.csv``. The file is
    appended to, not truncated, across runs.

    Raises:
        ValueError: missing or invalid config entries, or a kernel of unexpected shape.
        FileNotFoundError: a kernel or flip-counts file does not exist.
    """
    ap = argparse.ArgumentParser(description="Crossover analysis runner (σ_c proxy with jackknife)")
    ap.add_argument("--config", "-c", required=True)
    ap.add_argument("--output-dir", "-o", default="data/results/vol4_wilson_loop_pipeline_crossover_analysis")
    args = ap.parse_args()

    # An empty YAML file parses as None.
    cfg = load_yaml(args.config) or {}
    if not isinstance(cfg, dict):
        raise ValueError(f"Config {args.config!r} must be a YAML mapping at top level.")
    out_dir = pathlib.Path(to_abs(args.output_dir))
    ensure_dir(out_dir)
    out_csv = out_dir / "crossover_summary.csv"

    # `key:` with no value parses as None, so fall back with `or` rather than a get() default.
    cxa = dict(cfg.get("crossover_analysis") or {})
    gauges = _as_list(cxa.get("gauge_groups", ["SU2", "SU3"]))
    Ls = (
        _as_list(cxa.get("volumes"))
        or _as_list(cfg.get("L_values"))
        or _as_list(cfg.get("L_list"))
    )
    if not Ls:
        raise ValueError("No L list found (crossover_analysis.volumes or L_values/L_list).")
    bs = _as_list(cfg.get("b_values"))
    ks = _as_list(cfg.get("k_values"))
    n0s = _as_list(cfg.get("n0_values"))
    if not (bs and ks and n0s):
        raise ValueError("Missing b_values / k_values / n0_values in config.")

    kernel_tpl_by_g = _kernel_templates(cfg, cxa, gauges)

    flip_tpl = cfg.get("flip_counts_path_template")
    if not flip_tpl:
        raise ValueError("flip_counts_path_template missing in config.")

    pivot = dict(cxa.get("pivot") or {}) or dict(cfg.get("pivot") or {})
    for key in ("a", "b", "k", "n0", "beta", "logistic_k", "logistic_n0"):
        pivot.setdefault(key, None)

    # Optional fixed loop window from YAML; forwarded only if the compute function accepts it.
    loop_sizes_cfg = cxa.get("loop_sizes", None)
    compute_fn, compute_sig = find_compute_callable(cfg)
    pass_loop_sizes = loop_sizes_cfg is not None and "loop_sizes" in compute_sig.parameters

    header = ["L", "gauge", "b", "k", "n0", "sigma_c", "sigma_c_err", "kernel_path", "flip_counts_path", "ok"]

    for L in sorted(int(x) for x in Ls):
        # The flip-counts path depends only on L, so resolve and check it once per L.
        flip_path = flip_counts_path_from_tpl(flip_tpl, L)
        if not os.path.exists(flip_path):
            raise FileNotFoundError(f"Flip counts not found: {flip_path}")
        flip_rel = _rel_to_root(flip_path)

        for g in gauges:
            kernel_path = kernel_path_from_tpl(kernel_tpl_by_g[g], g, L)
            if not os.path.exists(kernel_path):
                raise FileNotFoundError(f"Kernel not found: {kernel_path}")
            _check_kernel_shape(kernel_path, L, g)
            kernel_rel = _rel_to_root(kernel_path)

            for b in bs:
                for k in ks:
                    for n0 in n0s:
                        kwargs = dict(
                            L=int(L), gauge=str(g),
                            kernel_path=kernel_path, flip_counts_path=flip_path,
                            # Fresh copy per call so a compute function that mutates
                            # its pivot cannot leak changes into later grid points.
                            pivot=dict(pivot), b=float(b), k=float(k), n0=float(n0),
                            job_tag=f"L{L}_{g}_b{b}_k{k}_n0{n0}",
                            work_dir=str(out_dir),
                        )
                        if pass_loop_sizes:
                            kwargs["loop_sizes"] = loop_sizes_cfg

                        sigma_c, sigma_c_err = _unpack_result(compute_fn(**kwargs))

                        save_csv_row(out_csv, {
                            "L": L, "gauge": g, "b": b, "k": k, "n0": n0,
                            "sigma_c": sigma_c,
                            "sigma_c_err": sigma_c_err,
                            "kernel_path": kernel_rel,
                            "flip_counts_path": flip_rel,
                            "ok": True,
                        }, header)

                        err_txt = f" ±{sigma_c_err:.3g}" if sigma_c_err is not None else ""
                        print(f"✅ crossover {g} L={L} b={b} k={k} n0={n0}  sigma_c={sigma_c:.6g}{err_txt}")

    print(f"\nSummary CSV → {out_csv}")

if __name__ == "__main__":
    # Run the sweep only when executed as a script, not when imported.
    # main() returns None, so the process exits with status 0 on success.
    sys.exit(main())
